# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate data for test."""
import numpy as np
import mindspore as ms
from mindformers.parallel_core.transformer_config import MLATransformerConfig


def get_init_params(config: MLATransformerConfig, seq_length=8, batch_size=2):
    """Build the query/key/value tensors and attention mask for FlashAttention.

    Args:
        config: Transformer config; only ``num_attention_heads`` and
            ``kv_channels`` are read here.
        seq_length: Sequence length used for the tensors and the square mask.
        batch_size: Batch size used for the tensors and the mask.

    Returns:
        dict: ``"query"``, ``"key"`` and ``"value"`` are bfloat16 tensors of
        shape ``(seq_length, batch_size, num_attention_heads, kv_channels)``
        (SBND layout). ``"attention_mask"`` is a uint8 tensor of shape
        ``(batch_size, 1, seq_length, seq_length)`` whose lower triangle
        (diagonal included) is 1 and the rest 0.

    Note:
        Reseeds NumPy's global random generator with 1 on every call, so the
        returned values are reproducible.
    """
    np.random.seed(1)
    sbnd_shape = (seq_length, batch_size, config.num_attention_heads, config.kv_channels)

    def _draw_bf16():
        # Small-magnitude Gaussian samples (std 0.01), cast to bfloat16.
        return ms.tensor(0.01 * np.random.randn(*sbnd_shape), ms.bfloat16)

    # The draw order (query, key, value) fixes which random values each
    # tensor receives; dict comprehensions evaluate in iteration order.
    params = {name: _draw_bf16() for name in ("query", "key", "value")}

    # Lower-triangular ones matrix, given two leading singleton axes and then
    # repeated along the first axis once per batch element.
    lower_tri = np.tril(np.ones((seq_length, seq_length), dtype=np.uint8))
    batched_mask = np.tile(lower_tri[np.newaxis, np.newaxis], (batch_size, 1, 1, 1))
    params["attention_mask"] = ms.tensor(batched_mask)
    return params


def get_golden() -> dict[str, np.ndarray]:
    """Return the hard-coded golden reference outputs for the test.

    Returns:
        dict[str, np.ndarray]: Two freshly built float32 arrays of shape
        ``(8, 2, 4)``, stored under the string keys ``'None'`` and ``'100.0'``.
    """
    # NOTE(review): the keys look like stringified values of a test parameter
    # (unset vs. 100.0) -- confirm against the test that indexes this dict.
    return {
        'None': np.array([
            [[0.0071191280, -0.0037393840, 0.0030939605, 0.0015579209],
             [0.0027723690, 0.0003985921, 0.0046713981, -0.0016514927]],
            [[0.0074381339, -0.0024556217, 0.0022729747, 0.0017401231],
             [0.0035460258, 0.0006343311, 0.0040016561, -0.0031776987]],
            [[0.0078668073, -0.0032225507, 0.0025722331, 0.0008512360],
             [0.0037909255, -0.0006040527, 0.0054220967, 0.0010560921]],
            [[0.0072363974, -0.0094947694, 0.0021116324, 0.0013144725],
             [0.0050793588, -0.0004573374, 0.0067336187, 0.0041246340]],
            [[0.0113723259, -0.0093368348, 0.0019863774, 0.0027413736],
             [0.0051215966, -0.0000274144, 0.0056912727, 0.0047886940]],
            [[0.0061050267, -0.0045233550, 0.0062140538, -0.0003948410],
             [-0.0049599926, 0.0012023055, 0.0083176121, 0.0083136018]],
            [[-0.0011044702, -0.0061736205, 0.0056276112, 0.0024073708],
             [0.0028066507, -0.0007311270, 0.0116033861, 0.0036949271]],
            [[0.0047302670, -0.0021929659, 0.0024810496, 0.0006083162],
             [0.0008888312, 0.0010367325, 0.0050781416, -0.0022250463]],
        ], dtype=np.float32),
        '100.0': np.array([
            [[0.0071010306, -0.0037096450, 0.0030966743, 0.0015520635],
             [0.0026768572, 0.0004320236, 0.0047259643, -0.0015588405]],
            [[0.0074322768, -0.0024548194, 0.0022462667, 0.0017557940],
             [0.0035279170, 0.0006353507, 0.0040422347, -0.0031076455]],
            [[0.0078567080, -0.0032778489, 0.0025831857, 0.0008490211],
             [0.0038640257, -0.0006233052, 0.0054533295, 0.0010916297]],
            [[0.0072169090, -0.0094421245, 0.0020905426, 0.0013172552],
             [0.0051141712, -0.0004685296, 0.0067053633, 0.0040948242]],
            [[0.0113667455, -0.0093460968, 0.0019858768, 0.0027365647],
             [0.0051432583, -0.0000388935, 0.0057794126, 0.0048830863]],
            [[0.0061017280, -0.0045241099, 0.0061999038, -0.0003272247],
             [-0.0050335983, 0.0012206289, 0.0083167590, 0.0083147995]],
            [[-0.0011044702, -0.0061736205, 0.0056276112, 0.0024073708],
             [0.0028066507, -0.0007311270, 0.0116033861, 0.0036949271]],
            [[0.0047302670, -0.0021929659, 0.0024810496, 0.0006083162],
             [0.0008888312, 0.0010367325, 0.0050781416, -0.0022250463]],
        ], dtype=np.float32),
    }


def get_gpu_datas() -> dict[str, np.ndarray]:
    """Return the hard-coded GPU reference outputs for the test.

    Returns:
        dict[str, np.ndarray]: Two freshly built float16 arrays of shape
        ``(8, 2, 4)``, stored under the string keys ``'None'`` and ``'100.0'``.
    """
    # NOTE(review): the keys look like stringified values of a test parameter
    # (unset vs. 100.0) -- confirm against the test that indexes this dict.
    return {
        'None': np.array([
            [[0.0071105957, -0.0037231445, 0.0030822754, 0.0015563965],
             [0.0027618408, 0.0003986359, 0.0046691895, -0.0016403198]],
            [[0.0074462891, -0.0024566650, 0.0022735596, 0.0017471313],
             [0.0035552979, 0.0006370544, 0.0040283203, -0.0031738281]],
            [[0.0078735352, -0.0032196045, 0.0025787354, 0.0008583069],
             [0.0037994385, -0.0006065369, 0.0054321289, 0.0010604858]],
            [[0.0072326660, -0.0094604492, 0.0021057129, 0.0013198853],
             [0.0050659180, -0.0004596710, 0.0067443848, 0.0041198730]],
            [[0.0113525391, -0.0093383789, 0.0019836426, 0.0027465820],
             [0.0051269531, -0.0000305176, 0.0057067871, 0.0047912598]],
            [[0.0061035156, -0.0045166016, 0.0062255859, -0.0003967285],
             [-0.0049438477, 0.0011978149, 0.0083007812, 0.0083007812]],
            [[-0.0011062622, -0.0061645508, 0.0056152344, 0.0024108887],
             [0.0028076172, -0.0007324219, 0.0115966797, 0.0036926270]],
            [[0.0047302246, -0.0021820068, 0.0024719238, 0.0006103516],
             [0.0008850098, 0.0010375977, 0.0050659180, -0.0022125244]],
        ], dtype=np.float16),
        '100.0': np.array([
            [[0.0070800781, -0.0036926270, 0.0030822754, 0.0015563965],
             [0.0026855469, 0.0004310608, 0.0047302246, -0.0015411377]],
            [[0.0074157715, -0.0024414062, 0.0022430420, 0.0017623901],
             [0.0035400391, 0.0006332397, 0.0040283203, -0.0031127930]],
            [[0.0078735352, -0.0032653809, 0.0025787354, 0.0008583069],
             [0.0038604736, -0.0006256104, 0.0054626465, 0.0010986328]],
            [[0.0072021484, -0.0094604492, 0.0020904541, 0.0013198853],
             [0.0050964355, -0.0004711151, 0.0067138672, 0.0040893555]],
            [[0.0113525391, -0.0093383789, 0.0019836426, 0.0027465820],
             [0.0051574707, -0.0000407696, 0.0057983398, 0.0049133301]],
            [[0.0061035156, -0.0045166016, 0.0061950684, -0.0003318787],
             [-0.0050048828, 0.0012130737, 0.0083007812, 0.0083007812]],
            [[-0.0011062622, -0.0061645508, 0.0056152344, 0.0024108887],
             [0.0028076172, -0.0007324219, 0.0115966797, 0.0036926270]],
            [[0.0047302246, -0.0021820068, 0.0024719238, 0.0006103516],
             [0.0008850098, 0.0010375977, 0.0050659180, -0.0022125244]],
        ], dtype=np.float16),
    }


# Reference tables built once at import time; get_golden() runs before get_gpu_datas().
GOLDEN_DATA, GPU_DATA = get_golden(), get_gpu_datas()
